We present here a core concept of the PySyft library. It is the ability to add new custom tensor types that can provide specific behaviors such as encryption or traceability. This feature makes our library very generic and completely open to new innovations in the field of privacy-preserving machine learning.
We will go through a very simple example which could be the base for a traceability feature that would keep track of the operations performed on the data in a verifiable way. This new tensor type will log all operations executed on tensors of its type. Let's call this type the CustomLoggingTensor.
Authors:
In [1]:
import torch as th
import syft as sy
sy.create_sandbox(globals(), verbose=False)
Let's first recall the notions of Torch and Syft tensors. All the objects the end user manipulates are torch tensors. This is of course the case when it's a pure torch tensor (ex: x = th.tensor([1., 2])), but also when you deal with syft objects, such as the pointer tensor, which is a particular case of syft tensor.
In [2]:
ptr = th.tensor([1., 2]).send(bob)
ptr
Out[2]:
The wrapper object you see is actually an empty torch tensor with a child attribute which is a PointerTensor:
In [3]:
isinstance(ptr, th.Tensor)
Out[3]:
In [4]:
type(ptr.child)
Out[4]:
This is also true for more complex objects, where you also see this wrapper at the beginning. You can then have multiple Syft or Torch tensors chained through the .child attribute.
In [5]:
x = th.tensor([1., 2]).fix_prec().share(alice, bob)
x
Out[5]:
In [6]:
x.child
Out[6]:
In [7]:
x.child.child.child
Out[7]:
Recall that the general behaviour is the following: each time a command is called on the top object, it goes down the chain, where it can be modified. It is then executed at the bottom, and the result is wrapped back up to have exactly the same chain structure, so that it keeps the same properties (such as traceability, for example).
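The down-and-back-up behaviour described above can be illustrated with a small pure-Python toy. This is only a sketch of the idea: the Leaf, Node, Wrapper and Logging classes and the describe helper are hypothetical names invented for illustration and are not part of PySyft.

```python
# Toy model of a tensor chain. None of these classes exist in PySyft;
# they only mimic the "go down, execute, wrap back" pattern.
class Leaf:
    """Bottom of the chain: holds the actual values."""

    def __init__(self, values):
        self.values = list(values)

    def mul(self, k):
        return Leaf(v * k for v in self.values)


class Node:
    """Intermediate link that forwards commands to its child."""

    def __init__(self, child):
        self.child = child

    def mul(self, k):
        # Go down the chain; the command is executed at the bottom...
        result = self.child.mul(k)
        # ...then the result is wrapped in the same node type on the way up.
        return type(self)(result)


class Wrapper(Node):
    pass


class Logging(Node):
    pass


def describe(obj):
    """Render a chain as 'Wrapper>Logging>[values]'."""
    parts = []
    while isinstance(obj, Node):
        parts.append(type(obj).__name__)
        obj = obj.child
    parts.append(str(obj.values))
    return ">".join(parts)


chain = Wrapper(Logging(Leaf([1.0, 2.0])))
doubled = chain.mul(2)
print(describe(chain))    # Wrapper>Logging>[1.0, 2.0]
print(describe(doubled))  # Wrapper>Logging>[2.0, 4.0]
```

The result has the same chain structure as the input, which is the property the real library preserves.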
What we're going to do here is to create our own syft Tensor type that we will be able to put in this chain!
To get started, there isn't much to do. First, we need to create the tensor class.
This is done in syft/frameworks/torch/tensors/. Choose the folder:
- interpreters if the functionality you want to build will modify the results or functions, or
- decorators if the functionality is just ... decorative.
The interpreters/decorators distinction might be removed in the future, in which case just put your tensor in syft/frameworks/torch/tensors/
Here we'll put it in the decorators folder. Choose a simple but explicit name; for now, decorators/custom_logging.py will be sufficient.
Write the minimal class definition there, where our tensor inherits from AbstractTensor, an abstraction which gives default behaviours to Syft tensors:
In [8]:
from syft.generic.abstract.tensor import AbstractTensor

class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
This was quite fast, wasn't it?
You now need to declare this type in the imports so that you can actually use it. Add it in the files:
- syft/frameworks/torch/tensors/[decorators|interpreters]/__init__.py
- syft/__init__.py
You should now be able to import the tensor type: from syft import CustomLoggingTensor
In the syft/frameworks/torch/hook/hook.py file, add the following to __init__: self._hook_syft_tensor_methods(CustomLoggingTensor)
All instances of CustomLoggingTensor now have, for example, a .add(...) method. We should now explain how to use it.
In particular, we would like arguments provided as CustomLoggingTensor to be unwrapped and replaced with their .child attribute, to go down the chain.
In the syft/generic/frameworks/hook/hook_args.py file, extend:
- the type_rule dict with CustomLoggingTensor: one, (means that this type of tensor supports (un)wrapping)
- the forward_func dict with CustomLoggingTensor: lambda i: i.child, (explains how to unwrap)
- the backward_func dict with CustomLoggingTensor: lambda i, **kwargs: CustomLoggingTensor(**kwargs).on(i, wrap=False), (explains how to wrap)
Et voilà! You can already do many things with your new tensor!
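To make the role of the three dicts more concrete, here is a minimal pure-Python sketch of the same idea. It is an assumption-laden toy: ToyTensor, one, unwrap_args and wrap_response are hypothetical names, and the real hook_args.py logic is considerably more involved.

```python
# Toy registries mirroring the idea of type_rule / forward_func / backward_func.
# ToyTensor and the helper functions are illustrative, not PySyft's API.
class ToyTensor:
    def __init__(self, child=None):
        self.child = child


def one(_arg):
    """Marker rule: this argument type can be (un)wrapped."""
    return 1


type_rule = {ToyTensor: one}                                 # which types are hooked
forward_func = {ToyTensor: lambda i: i.child}                # how to unwrap
backward_func = {ToyTensor: lambda i: ToyTensor(child=i)}    # how to wrap


def unwrap_args(args):
    """Replace every registered tensor by its child; leave other args alone."""
    return tuple(
        forward_func[type(a)](a) if type(a) in type_rule else a for a in args
    )


def wrap_response(response, wrap_type):
    """Put a node of the requested type back on top of the raw response."""
    return backward_func[wrap_type](response)


x = ToyTensor(child=3)
y = ToyTensor(child=4)
raw = unwrap_args((x, y, 10))
result = wrap_response(sum(raw), ToyTensor)
print(raw)           # (3, 4, 10)
print(result.child)  # 17
```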
In [9]:
x = CustomLoggingTensor()
x
Out[9]:
OK, this is not super useful, but it comes with a .on method which works as follows:
In [10]:
x = th.tensor([1., 2])
x = CustomLoggingTensor().on(x)
x
Out[10]:
.on simply inserts the tensor node into a tensor chain. As we always need to have a torch tensor at the top of the chain, a wrapper was automatically added.
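As a rough mental model of .on, the toy below inserts a node above some data and optionally adds a wrapper on top. ToyWrapper and ToyLoggingTensor are hypothetical stand-ins, and the sketch only covers the simple case where the input is not already part of a chain.

```python
# Toy sketch of the `.on` idea; not PySyft's implementation.
class ToyWrapper:
    """Stands in for the empty torch tensor at the top of a chain."""

    def __init__(self, child):
        self.child = child


class ToyLoggingTensor:
    def __init__(self):
        self.child = None

    def on(self, tensor, wrap=True):
        # Insert this node above the given tensor...
        self.child = tensor
        # ...and optionally add a wrapper so the chain starts with a "torch tensor".
        return ToyWrapper(self) if wrap else self


data = [1.0, 2.0]
wrapped = ToyLoggingTensor().on(data)
bare = ToyLoggingTensor().on(data, wrap=False)
print(type(wrapped).__name__, type(wrapped.child).__name__)  # ToyWrapper ToyLoggingTensor
print(type(bare).__name__)                                   # ToyLoggingTensor
```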
At this point, if you want the desired behaviour, you should make the code changes in the repository: integrating the code in the repository allows you to benefit from the hooking functionalities.
In particular, after re-instantiating the hook, your CustomLoggingTensor should have the methods a pure torch tensor has.
Make the change and re-run the notebook up to here
This time, we add the sy. prefix, meaning the code comes from the repo.
In [11]:
x = th.tensor([1., 2])
x = sy.CustomLoggingTensor().on(x)
x
Out[11]:
You can do computations on this chain, such as x * 2: the resulting __mul__ call is forwarded all the way down the chain to the last node, which is a pure torch tensor, whose value is doubled.
In [12]:
x * 2
Out[12]:
If you correctly obtained (Wrapper)>CustomLoggingTensor>tensor([2., 4.]), you're all set!
Now that you have defined your own tensor type, you should specify its behaviour, as by default it won't do anything special and will just act passively.
In this part, we will see how to specify custom functionalities. For the execution parts, we'll use the already existing LoggingTensor instead of the CustomLoggingTensor and highlight which part of the code produces which functionality, so that you can run code in this notebook without reloading the kernel. If you want to practice more, you can copy the code changes into the CustomLoggingTensor class definition and you'll observe the same behaviours (just reload the notebook each time you modify the library code).
In [13]:
from syft import LoggingTensor
In [14]:
class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def on_function_call(cls, command):
        """
        Override this to perform a specific action for each call of a torch
        function with arguments containing syft tensors of the class doing
        the overloading
        """
        cmd, _, args, kwargs = command
        print("Default log", cmd)
In [15]:
x = th.tensor([1., 2])
x = LoggingTensor().on(x)
th.div(x, x)
th.nn.functional.celu(x) # celu is a variant of the activation function relu(x) = max(0, x)
Out[15]:
Note: this on_function_call is called by handle_func_command, which comes from the AbstractTensor: it explains how to propagate functions down the chain, and in some cases you might also need to change it.
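The relationship between the two class methods can be pictured with the following toy, where a base class calls an overridable hook before dispatching the function. ToyBase, ToyLogging and the simplified (function, args) command tuple are assumptions made for illustration; the real command format and propagation logic differ.

```python
import operator


# Toy sketch: a base-class dispatcher that calls an overridable hook first.
# The command here is a simplified (function, args) pair, not PySyft's format.
class ToyBase:
    def __init__(self, child):
        self.child = child

    @classmethod
    def on_function_call(cls, command):
        """Default hook: do nothing."""

    @classmethod
    def handle_func_command(cls, command):
        cls.on_function_call(command)          # give the subclass a chance to react
        func, args = command
        unwrapped = [a.child for a in args]    # go one level down the chain
        return cls(func(*unwrapped))           # wrap the result back


calls = []


class ToyLogging(ToyBase):
    @classmethod
    def on_function_call(cls, command):
        func, _ = command
        calls.append(func.__name__)


x = ToyLogging(6.0)
out = ToyLogging.handle_func_command((operator.truediv, (x, x)))
print(calls, out.child)  # ['truediv'] 1.0
```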
In [16]:
from syft.generic.frameworks.overload import overloaded
You can directly overload torch methods like this. Here we overload the .add method so that we first print that it was called and then forward the call to the .child attribute.
In [17]:
class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @overloaded.method
    def add(self, _self, *args, **kwargs):
        print("Log method add")
        response = _self.add(*args, **kwargs)
        return response
Here is an example of how to use the @overloaded.method decorator. To see what this decorator does, just look at the next method, manual_add: it does exactly the same but without the decorator.
Note the subtlety between self and _self: you should use _self and NOT self. We kept self because it can hold useful attributes that you might want to access (for example, for the fixed precision tensor it stores the field size).
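The self versus _self distinction may be easier to see in a stripped-down decorator. The sketch below is a toy analogue written for this explanation: toy_overloaded_method and ToyLogging are hypothetical, and the real @overloaded.method relies on the hook_args machinery rather than this simplified unwrapping.

```python
import functools


def toy_overloaded_method(method):
    """Toy analogue: unwrap self and args, call the method, rewrap the result."""

    @functools.wraps(method)
    def hooked(self, *args, **kwargs):
        _self = self.child  # unwrapped version of self
        new_args = tuple(
            a.child if isinstance(a, type(self)) else a for a in args
        )
        response = method(self, _self, *new_args, **kwargs)
        return type(self)(response)  # put a node of the same type back on top

    return hooked


class ToyLogging:
    def __init__(self, child, label="log"):
        self.child = child
        self.label = label

    @toy_overloaded_method
    def add(self, _self, other):
        # self still carries attributes (label); _self is the unwrapped child.
        print(f"[{self.label}] add")
        return _self + other


x = ToyLogging(1.5)
r = x.add(x)
print(type(r).__name__, r.child)  # ToyLogging 3.0
```

The computation uses _self and the unwrapped argument, while self remains available for reading attributes such as label.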
Here is the version of the add method without the decorator: as you can see, it is much more complicated. However, you might sometimes need it to specify some particular behaviour, so here is where to start if needed!
In [18]:
import syft  # needed for the fully qualified syft.generic... calls below

class CustomLoggingTensor(AbstractTensor):
    # [...]
    def manual_add(self, *args, **kwargs):
        # Replace all syft tensor with their child attribute
        new_self, new_args, new_kwargs = syft.generic.frameworks.hook.hook_args.hook_method_args(
            "add", self, args, kwargs
        )
        print("Log method manual_add")
        # Send it to the appropriate class and get the response
        response = getattr(new_self, "add")(*new_args, **new_kwargs)
        # Put back SyftTensor on the tensors found in the response
        response = syft.generic.frameworks.hook.hook_args.hook_response(
            "add", response, wrap_type=type(self)
        )
        return response
They behave exactly the same and print a line when called
In [19]:
x = LoggingTensor().on(th.tensor([1., 2]))
print(x)
r = x.add(x)
You might want to try running r = x.manual_add(x), but this will fail: even if the LoggingTensor stored in x.child had a .manual_add method, the wrapper x would not, as torch tensors don't have .manual_add by default.
In [20]:
class CustomLoggingTensor(AbstractTensor):
    # [...]
    @staticmethod
    @overloaded.module
    def torch(module):
        """
        We use the @overloaded.module to specify we're writing here
        a function which should overload the function with the same
        name in the <torch> module
        :param module: object which stores the overloading functions
        Note that we used the @staticmethod decorator as we're in a
        class
        """

        def add(x, y):
            """
            You can write the function to overload in the most natural
            way, so this will be called whenever you call torch.add on
            Logging Tensors, and the x and y you get are also Logging
            Tensors, so compared to the @overloaded.method, you see
            that the @overloaded.module does not hook the arguments.
            """
            print("Log function torch.add")
            return x + y

        # Just register it using the module variable
        module.add = add

        @overloaded.function
        def mul(x, y):
            """
            You can also add the @overloaded.function decorator to also
            hook arguments, ie all the LoggingTensor are replaced with
            their child attribute
            """
            print("Log function torch.mul")
            return x * y

        # Just register it using the module variable
        module.mul = mul

        # You can also overload functions in submodules!
        @overloaded.module
        def nn(module):
            """
            The syntax is the same, so @overloaded.module handles recursion
            Note that we don't need to add the @staticmethod decorator
            """

            @overloaded.module
            def functional(module):
                @overloaded.function
                def relu(x):
                    print("Log function torch.nn.functional.relu")
                    return x * (x > 0).float()

                module.relu = relu

            module.functional = functional

        # Modules should be registered just like functions
        module.nn = nn
Note the difference between def add and def mul: def add doesn't have @overloaded.function, which means that the args inside are not unwrapped (they are CustomLoggingTensors), while in def mul they are unwrapped and replaced by their child attributes, so Torch tensors in our case.
Look how it changes compared to 2.1: the behaviour is not much different, but now the modified functions are very precisely targeted:
In [21]:
x = th.tensor([1., 2])
x = LoggingTensor().on(x)
# Default overloading made in 2.1
r = th.div(x, x)
# Targeted overloading made in 2.3
r = th.add(x, x)
Also, compared to 2.1, we changed the function behaviour: for relu, for example, instead of running the built-in relu we run x * (x > 0), even if the output is the same. We could also have called the native relu inside if we wanted, provided that we unwrap the args using @overloaded.function; otherwise we would loop indefinitely.
@overloaded.module
def functional(module):
    @overloaded.function
    def relu(x):
        print("Log function torch.nn.functional.relu")
        return torch.nn.functional.relu(x)
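Why forgetting to unwrap leads to an endless loop can be reproduced in a few lines of plain Python. In this toy, Wrapped, native_relu, dispatch_relu and the two overloads are hypothetical names; the dispatcher merely imitates the idea that wrapped inputs are routed to the overload while raw values reach the native function.

```python
# Toy reproduction of the "loop indefinitely" pitfall; not PySyft code.
class Wrapped:
    def __init__(self, child):
        self.child = child


def native_relu(x):
    return max(0.0, x)


def dispatch_relu(x):
    """Route wrapped inputs to the overload, raw floats to the native function."""
    if isinstance(x, Wrapped):
        return overload(x)
    return native_relu(x)


def looping_overload(x):
    # x is still wrapped, so the dispatcher sends it straight back here.
    return dispatch_relu(x)


def unwrapping_overload(x):
    # Unwrap first, so the dispatcher reaches the native function.
    return Wrapped(dispatch_relu(x.child))


overload = looping_overload
try:
    dispatch_relu(Wrapped(-2.0))
    looped = False
except RecursionError:
    looped = True

overload = unwrapping_overload
result = dispatch_relu(Wrapped(3.0))
print(looped, result.child)  # True 3.0
```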
In [22]:
fp = th.tensor([1., 2]).fix_precision()
print(fp)
print("Field:", fp.child.field)
To add attributes to your tensor type, just declare them in the __init__, for example a log_max_size:
In [23]:
class CustomLoggingTensor(AbstractTensor):
    def __init__(self, log_max_size=64, **kwargs):
        super().__init__(**kwargs)
        self.log_max_size = log_max_size

    # [...]
To make sure this value gets correctly added to the response of an operation, when the chain is rebuilt and a CustomLoggingTensor is wrapped on top of the result, you should declare get_class_attributes:
In [24]:
class CustomLoggingTensor(AbstractTensor):
    # [...]
    def get_class_attributes(self):
        """
        Return all elements which define an instance of a certain class.
        """
        return {
            'log_max_size': self.log_max_size
        }
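The purpose of get_class_attributes can be sketched with a toy in which the result of an operation is rewrapped using the attributes the instance declares. ToyLogging and its add method are hypothetical; in the library, the rewrapping is performed by the hooking machinery rather than by the method itself.

```python
# Toy sketch: attributes declared by get_class_attributes survive an operation.
class ToyLogging:
    def __init__(self, child=None, log_max_size=64):
        self.child = child
        self.log_max_size = log_max_size

    def get_class_attributes(self):
        return {"log_max_size": self.log_max_size}

    def add(self, other):
        response = self.child + other.child
        # Rebuild the node around the result, reusing the declared attributes.
        return type(self)(child=response, **self.get_class_attributes())


x = ToyLogging(child=1.0, log_max_size=8)
r = x.add(x)
print(r.child, r.log_max_size)  # 2.0 8
```

Without the attributes being passed along, the rebuilt node would silently fall back to the default log_max_size of 64.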
One last thing we love to do is to send tensors across workers!
To do so, you need to add a serializer and a deserializer to the class:
In [25]:
# Add these new imports
import syft
from syft.workers.abstract import AbstractWorker

class CustomLoggingTensor(AbstractTensor):
    # [...]
    @staticmethod
    def simplify(tensor: "CustomLoggingTensor") -> tuple:
        """Takes the attributes of a CustomLoggingTensor and saves them in a tuple.
        Args:
            tensor: a CustomLoggingTensor.
        Returns:
            tuple: a tuple holding the unique attributes of the CustomLoggingTensor.
        """
        chain = None
        if hasattr(tensor, "child"):
            chain = syft.serde._simplify(tensor.child)
        return (
            syft.serde._simplify(tensor.id),
            tensor.log_max_size,
            syft.serde._simplify(tensor.tags),
            syft.serde._simplify(tensor.description),
            chain,
        )

    @staticmethod
    def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "CustomLoggingTensor":
        """
        This function reconstructs a CustomLoggingTensor given its attributes in form of a tuple.
        Args:
            worker: the worker doing the deserialization
            tensor_tuple: a tuple holding the attributes of the CustomLoggingTensor
        Returns:
            CustomLoggingTensor: a CustomLoggingTensor
        Examples:
            shared_tensor = detail(data)
        """
        tensor_id, log_max_size, tags, description, chain = tensor_tuple
        tensor = CustomLoggingTensor(
            owner=worker,
            id=syft.serde._detail(worker, tensor_id),
            log_max_size=log_max_size,
            tags=syft.serde._detail(worker, tags),
            description=syft.serde._detail(worker, description),
        )
        if chain is not None:
            chain = syft.serde._detail(worker, chain)
            tensor.child = chain
        return tensor
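The simplify/detail pair above follows a round-trip pattern that can be tried out in isolation. The sketch below is a simplified toy under stated assumptions: ToyLogging, REGISTRY, toy_simplify and toy_detail are hypothetical names, there is no worker argument, and the tuple layout is not the one PySyft actually uses.

```python
# Toy round trip: object -> nested tuples -> object. Not PySyft's serde.
class ToyLogging:
    def __init__(self, id=None, log_max_size=64, child=None):
        self.id = id
        self.log_max_size = log_max_size
        self.child = child

    @staticmethod
    def simplify(tensor):
        chain = None
        if tensor.child is not None:
            chain = toy_simplify(tensor.child)  # recurse down the chain
        return (tensor.id, tensor.log_max_size, chain)

    @staticmethod
    def detail(tensor_tuple):
        tensor_id, log_max_size, chain = tensor_tuple
        tensor = ToyLogging(id=tensor_id, log_max_size=log_max_size)
        if chain is not None:
            tensor.child = toy_detail(chain)
        return tensor


# Types must be registered; the position in the list serves as a type code.
REGISTRY = [ToyLogging]


def toy_simplify(obj):
    if type(obj) in REGISTRY:
        return (REGISTRY.index(type(obj)), type(obj).simplify(obj))
    return (None, obj)  # plain values pass through untouched


def toy_detail(data):
    code, payload = data
    if code is None:
        return payload
    return REGISTRY[code].detail(payload)


x = ToyLogging(id=7, log_max_size=8, child=[1.0, 2.0])
packed = toy_simplify(x)
restored = toy_detail(packed)
print(packed)  # (0, (7, 8, (None, [1.0, 2.0])))
print(restored.id, restored.log_max_size, restored.child)  # 7 8 [1.0, 2.0]
```

A type that is missing from the registry would be treated as a plain value here, which hints at why a new tensor type has to be declared to the serialization module before it can be sent.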
And to declare this new tensor to the ser/deser module, in serde/serde.py add CustomLoggingTensor to the OBJ_SIMPLIFIER_AND_DETAILERS list.
Everything should now work correctly:
In [26]:
x = th.tensor([1., 2])
x = sy.LoggingTensor().on(x)
p = x.send(alice)
print(p)
p2 = p + p
x2 = p2.get()
print(x2)
And here you are: you should now understand all the tools we've built so that you can easily build new tensor types and focus on their behaviour rather than on their integration in the PySyft library.
Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!
The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.
The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at http://slack.openmined.org
The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for "Projects". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".
If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!